import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden: Width of the single hidden layer.
        output_dim: Size of the last dimension of the output tensor.
    """

    def __init__(self, input_dim, hidden, output_dim):
        super().__init__()
        # Attribute names (fc1, fc3) are part of the state_dict layout,
        # so they are kept as-is.
        self.fc1 = nn.Linear(input_dim, hidden, bias=True)
        self.fc3 = nn.Linear(hidden, output_dim, bias=True)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc3 to ``x`` and return the result."""
        return self.fc3(F.relu(self.fc1(x)))
    
class MLP_Scoring(nn.Module):
    """Single linear layer followed by a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden: Accepted but not used by this module.
        output_dim: Size of the last dimension of the output tensor.
        num_hidden_layers: Accepted but not used by this module.
    """

    def __init__(self, input_dim, hidden, output_dim, num_hidden_layers=0):
        super().__init__()
        # Bias is left at nn.Linear's default.
        self.fc1 = nn.Linear(input_dim, output_dim)
        # The same Linear instance is registered under both ``fc1`` and
        # ``fc_out.0``; the two names share one set of parameters.
        squash = nn.Sigmoid()
        self.fc_out = nn.Sequential(self.fc1, squash)

    def forward(self, x, printit=False):
        """Return ``sigmoid(fc1(x))``; ``printit`` is accepted but ignored."""
        return self.fc_out(x)

class MLP_VEC(nn.Module):
    """Feed-forward network with a configurable number of hidden layers.

    Layout: Linear(input_dim, hidden) -> ReLU, then ``num_hidden_layers``
    repetitions of Linear(hidden, hidden) -> ReLU, then a final
    Linear(hidden, output_dim) with no activation.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden: Width of every hidden layer.
        output_dim: Size of the last dimension of the output tensor.
        num_hidden_layers: Number of extra hidden-to-hidden layers
            placed between the input and output projections.
    """

    def __init__(self, input_dim, hidden, output_dim, num_hidden_layers=0):
        super().__init__()
        # Bias is left at nn.Linear's default for every layer.
        self.fc1 = nn.Linear(input_dim, hidden)
        self.nlnr = nn.ReLU()
        self.hidden_layers = nn.ModuleList()
        for _ in range(num_hidden_layers):
            self.hidden_layers.append(nn.Linear(hidden, hidden))
        self.fc_out = nn.Linear(hidden, output_dim)

    def forward(self, x, printit=False):
        """Run ``x`` through the network; ``printit`` is accepted but ignored."""
        activations = self.nlnr(self.fc1(x))
        for block in self.hidden_layers:
            activations = self.nlnr(block(activations))
        return self.fc_out(activations)